import torch
from utils.paramUtil import t2m_kinematic_chain as kinematic_chain
from data.quaternion import *
from data.utils import *
from data.interx_utils import *
import torch.nn as nn
import torch.nn.functional as F
class Geometric_Losses:
    def __init__(self, args, recons_loss, conv_dim, joints_num, dataset_name, device):
        """Store the configuration and build the criterion and dataset helpers.

        Args:
            args: Option object kept on ``self.args``; ``forward`` reads
                ``Center_select`` and ``Traj_select`` from it.
            recons_loss: ``'l1'`` selects ``torch.nn.L1Loss`` and
                ``'l1_smooth'`` selects ``torch.nn.SmoothL1Loss``. Any other
                value leaves ``self.l1_criterion`` unset.
            conv_dim: Stored as-is; ``forward`` branches on the values 1 and 2.
            joints_num: Number of joints per skeleton.
            dataset_name: ``'interhuman'`` or ``'interx'``; any other value
                leaves ``self.normalizer`` unset.
            device: Forwarded to ``MotionNormalizerTorch`` for ``'interhuman'``.
        """
        self.args = args

        # Both criteria are created with their default reduction.
        if recons_loss == 'l1':
            self.l1_criterion = torch.nn.L1Loss()
        elif recons_loss == 'l1_smooth':
            self.l1_criterion = torch.nn.SmoothL1Loss()

        self.conv_dim = conv_dim
        self.joints_num = joints_num
        # Joint indices used by the foot-contact loss
        # (presumably left-foot ids followed by right-foot ids -- TODO confirm).
        self.fids = [*fid_l, *fid_r]

        self.dataset_name = dataset_name
        if self.dataset_name == 'interhuman':
            self.normalizer = MotionNormalizerTorch(device)
        elif self.dataset_name == 'interx':
            self.normalizer = InterxNormalizerTorch()
            self.kinematics = InterxKinematics()

    def calc_foot_contact(self, motion, pred_motion):
            if self.dataset_name == 'interhuman':
                B, T, _ = motion.shape
                motion = motion[..., :self.joints_num * 3]
                motion = motion.reshape(B, T, self.joints_num, 3)
                
                pred_motion = pred_motion[..., :self.joints_num * 3]
                pred_motion = pred_motion.reshape(B, T, self.joints_num, 3)
            
            
            feet_vel = motion[:, 1:, self.fids, :] - motion[:, :-1, self.fids,:]
            pred_feet_vel = pred_motion[:, 1:, self.fids, :] - pred_motion[:, :-1, self.fids,:]
            feet_h = motion[:, :-1, self.fids, 1]
            pred_feet_h = pred_motion[:, :-1, self.fids, 1]
            # contact = target[:,:-1,:,-8:-4] # [b,t,p,4]

            ## Calculate contacts
            thres = 0.001
            velfactor, heightfactor = torch.Tensor([thres, thres, thres, thres]).to(feet_vel.device), torch.Tensor(
                [0.12, 0.05, 0.12, 0.05]).to(feet_vel.device)

            feet_x = (feet_vel[..., 0]) ** 2
            feet_y = (feet_vel[..., 1]) ** 2
            feet_z = (feet_vel[..., 2]) ** 2
            contact = ((feet_x + feet_y + feet_z) < velfactor) & (feet_h < heightfactor)
            
            fc_loss = self.l1_criterion(pred_feet_vel[contact], torch.zeros_like(pred_feet_vel)[contact])
            if torch.isnan(fc_loss):
                fc_loss = torch.tensor(0).to(motion.device)
                if contact.sum() != 0:
                    print("FC nan but contact not 0")
            return fc_loss
                                 
    def calc_bone_lengths(self, motion):
        """Return the length of every bone defined by ``kinematic_chain``.

        For ``'interhuman'`` the first ``joints_num * 3`` channels of
        ``motion`` are reshaped to ``[dim0, dim1, joints_num, 3]``; for
        ``'interx'`` ``motion`` is used as joint positions directly.

        Returns:
            Tensor whose last axis holds one length per consecutive joint pair
            of each chain, in chain order.
        """
        if self.dataset_name == 'interhuman':
            flat_pos = motion[..., :self.joints_num * 3]
            motion_pos = flat_pos.reshape(flat_pos.shape[0], flat_pos.shape[1], self.joints_num, 3)
        elif self.dataset_name == 'interx':
            motion_pos = motion

        lengths = []
        for chain in kinematic_chain:
            # Consecutive entries of a chain form one bone.
            for joint_a, joint_b in zip(chain[:-1], chain[1:]):
                segment = motion_pos[..., joint_a, :] - motion_pos[..., joint_b, :]
                lengths.append(segment.norm(dim=-1, keepdim=True))

        return torch.cat(lengths, dim=-1)
    
    def calc_loss_geo(self, pred_rot, gt_rot, eps=1e-7):
        """Return the mean rotation angle between target and predicted rotations.

        Both inputs are converted with ``cont6d_to_matrix`` and flattened to a
        batch of 3x3 matrices. The angle of ``R_gt @ R_pred^T`` is recovered
        from its trace as ``acos((trace - 1) / 2)``.

        Args:
            pred_rot: Predicted rotations. For ``'interhuman'`` the trailing
                feature axis is split into groups of six values.
            gt_rot: Target rotations with the same layout.
            eps: Margin keeping the cosine strictly inside ``(-1, 1)`` before
                ``acos``.

        Returns:
            Scalar tensor, the mean angle in radians.
        """
        if self.dataset_name == "interhuman":
            # Flat features -> one 6-value rotation per joint.
            pred_rot = pred_rot.reshape(pred_rot.shape[0], pred_rot.shape[1], -1, 6)
            gt_rot = gt_rot.reshape(gt_rot.shape[0], gt_rot.shape[1], -1, 6)

        pred_mats = cont6d_to_matrix(pred_rot).reshape(-1, 3, 3)
        gt_mats = cont6d_to_matrix(gt_rot).reshape(-1, 3, 3)

        # Relative rotation between target and prediction, one per matrix.
        relative = torch.bmm(gt_mats, pred_mats.transpose(1, 2))
        trace = relative[:, 0, 0] + relative[:, 1, 1] + relative[:, 2, 2]
        cos_angle = (trace - 1) / 2

        angles = torch.acos(cos_angle.clamp(-1 + eps, 1 - eps))
        return angles.mean()
    # Centre-of-mass distance
    def compute_center_distance(self,joint_positions, root_positions):
        """
        计算重心位置并返回重心与根位置之间的欧氏距离
        :param joint_positions: 关节位置，形状为 (batch_size, num_joints, 3)
        :param root_positions: 根节点位置，形状为 (batch_size, 3)
        :return: 重心距离 (batch_size,) 和重心位置 (batch_size, 3)
        """
        # 计算重心位置
        center_of_mass = torch.mean(joint_positions, dim=2,keepdim=True)
        # 计算重心与根节点之间的欧氏距离
        distance = torch.norm(center_of_mass - root_positions, p=2, dim=-1)
        return distance, center_of_mass

    # Cosine of the angle between a vector and the normal of the XY plane (z axis)
    def calculate_angle_with_xy_plane(self,A, B):
        """
        计算两点间连线与XY平面之间的余弦夹角
        :param A: 点 A，形状为 (batch_size, 3)
        :param B: 点 B，形状为 (batch_size, 3)
        :return: 余弦值 (batch_size,)
        """
        AB = B - A
        n_xy = torch.tensor([0.0, 0.0, 1.0], device=A.device).expand_as(AB)  # XY平面法向量
        # 计算点积
        dot_product = torch.sum(AB * n_xy, dim=-1)
        AB_norm = torch.norm(AB, p=2, dim=-1)  # AB的模长
        n_xy_norm = torch.norm(n_xy, p=2, dim=-1)  # n_xy的模长
        # 计算余弦值
        cos_theta = dot_product / (AB_norm * n_xy_norm)
        return cos_theta

    # Total centre loss: centroid-distance term plus angle term
    def calc_center_loss(self,pred_joint_positions,gt_joint_positions):
        """
        计算重心距离的L1损失和角度的L1损失
        :param pred_joint_positions: 预测的关节位置 (batch_size, num_joints, 3)
        :param pred_root_positions: 预测的根节点位置 (batch_size, 3)
        :param gt_joint_positions: 真实的关节位置 (batch_size, num_joints, 3)
        :param gt_root_positions: 真实的根节点位置 (batch_size, 3)
        :return: 总损失值
        """
        b,t,dim=pred_joint_positions.shape
        pred_joint_positions=pred_joint_positions.reshape(b,t,self.joints_num,3)
        gt_joint_positions=gt_joint_positions.reshape(b,t,self.joints_num,3)
        pred_root_positions = pred_joint_positions[..., :1, :]
        gt_root_positions = gt_joint_positions[..., :1, :]
        # 计算预测和真实的重心距离
        pred_distance, pred_center_of_mass = self.compute_center_distance(pred_joint_positions, pred_root_positions)
        gt_distance, gt_center_of_mass = self.compute_center_distance(gt_joint_positions, gt_root_positions)
        
        # 计算重心距离的L1损失
        center_loss = self.l1_criterion(pred_distance, gt_distance)
        
        # 计算预测和真实角度的余弦值
        pred_angle_cos = self.calculate_angle_with_xy_plane(pred_root_positions, pred_center_of_mass)
        gt_angle_cos = self.calculate_angle_with_xy_plane(gt_root_positions, gt_center_of_mass)
        
        # 计算角度的L1损失
        angle_loss = self.l1_criterion(pred_angle_cos, gt_angle_cos)
        
        # 总损失 = 重心损失 + 角度损失
        total_loss = center_loss + angle_loss
        return total_loss
        # return total_loss, center_loss, angle_loss
    def calc_specific_joints_trajectory_loss(self, pred_motion_joints, motion_joints):
        """Calculate trajectory loss for specific joints (0,10,11,20,21) for two people
        
        Args:
            pred_motion_joints: Predicted motion [B,T,2,J,3]
            motion_joints: Ground truth motion [B,T,2,J,3]
            
        Returns:
            Average loss across both people for specified joints
        """
        # Target joint indices
        target_joints = [0, 10, 11, 20, 21]
        
        # Reshape inputs for each person [B,T,J,3]
        b,t,dim=pred_motion_joints.shape
        pred_motion_joints=pred_motion_joints.reshape(b,t,self.joints_num,3)
        motion_joints=motion_joints.reshape(b,t,self.joints_num,3)
        
       
        def calc_person_loss(pred, gt):
            loss = 0
            for joint_idx in target_joints:
                # Extract XZ coordinates
                pred_traj = pred[..., joint_idx, [0,2]]
                gt_traj = gt[..., joint_idx, [0,2]]
                # Calculate L1 loss
                joint_loss = self.l1_criterion(pred_traj, gt_traj).mean()
                loss += joint_loss
            return loss / len(target_joints)
        
        # Calculate losses for each person
        p1_loss = calc_person_loss(pred_motion_joints, motion_joints)        
        # Return average loss across both people
        return p1_loss
    def calc_trajectory_loss(self,pred_motion_joints,motion_joints):
        """
        计算两个人的轨迹损失
        :param motion_joints: 真实轨迹 [B,T,2,J,3]
        :param pred_motion_joints: 预测轨迹 [B,T,2,J,3]
        :return: 轨迹损失
        """
        # 提取根节点(通常是第0个关节)的xz平面坐标
        b,t,dim=pred_motion_joints.shape
        pred_motion_joints=pred_motion_joints.reshape(b,t,self.joints_num,3)
        motion_joints=motion_joints.reshape(b,t,self.joints_num,3)
        gt_traj1 = motion_joints[..., 0, [0,2]]  # 人1的真实轨迹 [B,T,2]
        # gt_traj2 = motion_joints[..., 1, 0, [0,2]]  # 人2的真实轨迹 [B,T,2] 
        
        pred_traj1 = pred_motion_joints[..., 0, [0,2]]  # 人1的预测轨迹 [B,T,2]
        # pred_traj2 = pred_motion_joints[..., 1, 0, [0,2]]  # 人2的预测轨迹 [B,T,2]

        # 计算两个人的轨迹损失
        traj_loss1 = self.l1_criterion(pred_traj1, gt_traj1).mean()
        # traj_loss2 = self.l1_criterion(pred_traj2, gt_traj2).mean()
        
        # 总轨迹损失
        total_traj_loss = traj_loss1
        
        return total_traj_loss
    # Cosine of the angle between a vector and the normal of the XZ plane (y axis)
    def calculate_angle_with_xz_plane(self,A, B):
        """
        计算两点间连线与XY平面之间的余弦夹角
        :param A: 点 A，形状为 (batch_size, 3)
        :param B: 点 B，形状为 (batch_size, 3)
        :return: 余弦值 (batch_size,)
        """
        AB = B - A
        n_xy = torch.tensor([0.0, 1.0, 0.0], device=A.device).expand_as(AB)  # XY平面法向量
        # 计算点积
        dot_product = torch.sum(AB * n_xy, dim=-1)
        AB_norm = torch.norm(AB, p=2, dim=-1)  # AB的模长
        n_xy_norm = torch.norm(n_xy, p=2, dim=-1)  # n_xy的模长
        # 计算余弦值
        cos_theta = dot_product / (AB_norm * n_xy_norm)
        return cos_theta

    # Total centre loss, XZ-plane variant: centroid-distance term plus angle term
    def calc_centerxz_loss(self,pred_joint_positions,gt_joint_positions):
        """
        计算重心距离的L1损失和角度的L1损失
        :param pred_joint_positions: 预测的关节位置 (batch_size, num_joints, 3)
        :param pred_root_positions: 预测的根节点位置 (batch_size, 3)
        :param gt_joint_positions: 真实的关节位置 (batch_size, num_joints, 3)
        :param gt_root_positions: 真实的根节点位置 (batch_size, 3)
        :return: 总损失值
        """
        if len(pred_joint_positions.shape)==3:
            b,t,dim=pred_joint_positions.shape
            pred_joint_positions=pred_joint_positions.reshape(b,t,self.joints_num,3)
            gt_joint_positions=gt_joint_positions.reshape(b,t,self.joints_num,3)
        pred_root_positions = pred_joint_positions[..., :1, :]
        gt_root_positions = gt_joint_positions[..., :1, :]
        # 计算预测和真实的重心距离
        pred_distance, pred_center_of_mass = self.compute_center_distance(pred_joint_positions, pred_root_positions)
        gt_distance, gt_center_of_mass = self.compute_center_distance(gt_joint_positions, gt_root_positions)
        
        # 计算重心距离的L1损失
        center_loss = self.l1_criterion(pred_distance, gt_distance)
        
        # 计算预测和真实角度的余弦值
        pred_angle_cos = self.calculate_angle_with_xz_plane(pred_root_positions, pred_center_of_mass)
        gt_angle_cos = self.calculate_angle_with_xz_plane(gt_root_positions, gt_center_of_mass)
        
        # 计算角度的L1损失
        angle_loss = self.l1_criterion(pred_angle_cos, gt_angle_cos)
        
        # 总损失 = 重心损失 + 角度损失
        total_loss = center_loss + angle_loss
        return total_loss
    def forward(self, motions, pred_motion):
        """Compute all geometric losses for one batch.

        Args:
            motions: Target motion features.
            pred_motion: Predicted motion features with the same layout.

        Returns:
            For ``'interhuman'``: ``(loss_rec, loss_explicit, loss_vel,
            loss_bn, loss_geo, loss_fc, loss_center, loss_trajectory)``.
            For ``'interx'``: ``(loss_rec, loss_explicit, loss_vel, loss_bn,
            loss_geo, loss_fc, motions_pos, pred_motions_pos)``, i.e. the last
            two entries are the joint positions produced by
            ``self.kinematics.forward``.
            ``None`` for any other dataset name.
        """
        if self.dataset_name == 'interhuman':
            pos_dim = self.joints_num * 3
            # Rotation channels start at joints_num * 6 and cover joints_num - 1 joints.
            rot_begin = self.joints_num * 6
            rot_end = rot_begin + (self.joints_num - 1) * 6

            if self.conv_dim == 1:
                loss_rec = self.l1_criterion(pred_motion, motions)
            elif self.conv_dim == 2:
                # The last four channels are left out of the reconstruction term.
                loss_rec = self.l1_criterion(pred_motion[..., :-4], motions[..., :-4])

            pred_pos = pred_motion[:, :, :pos_dim]
            gt_pos = motions[:, :, :pos_dim]

            loss_explicit = self.l1_criterion(pred_pos, gt_pos)
            loss_vel = self.l1_criterion(pred_pos[:, 1:] - pred_pos[:, :-1],
                                         gt_pos[:, 1:] - gt_pos[:, :-1])
            loss_bn = self.l1_criterion(self.calc_bone_lengths(pred_motion),
                                        self.calc_bone_lengths(motions))
            loss_geo = self.calc_loss_geo(pred_motion[..., rot_begin:rot_end],
                                          motions[..., rot_begin:rot_end])
            # Foot contact is evaluated on de-normalised motion.
            loss_fc = self.calc_foot_contact(self.normalizer.backward(motions),
                                             self.normalizer.backward(pred_motion))

            if self.args.Center_select == "Center":
                loss_center = self.calc_center_loss(pred_pos, gt_pos)
            elif self.args.Center_select == "CenterXZ":
                loss_center = self.calc_centerxz_loss(pred_pos, gt_pos)
            else:
                loss_center = torch.tensor(0.0).to(motions.device)

            if self.args.Traj_select == "Trajectory":
                loss_trajectory = self.calc_trajectory_loss(pred_pos, gt_pos)
            elif self.args.Traj_select == "Specific":
                loss_trajectory = self.calc_specific_joints_trajectory_loss(pred_pos, gt_pos)
            else:
                loss_trajectory = torch.tensor(0.0).to(motions.device)

            return (loss_rec, loss_explicit, loss_vel, loss_bn,
                    loss_geo, loss_fc, loss_center, loss_trajectory)

        elif self.dataset_name == 'interx':
            loss_rec = self.l1_criterion(pred_motion, motions)

            # Joint positions derived from the motion representation.
            pred_motions_pos = self.kinematics.forward(pred_motion)
            motions_pos = self.kinematics.forward(motions)

            loss_explicit = self.l1_criterion(pred_motions_pos, motions_pos)
            loss_vel = self.l1_criterion(
                pred_motions_pos[:, 1:, :, :] - pred_motions_pos[:, :-1, :, :],
                motions_pos[:, 1:, :, :] - motions_pos[:, :-1, :, :])

            # Bone lengths use only the first 22 joints.
            loss_bn = self.l1_criterion(self.calc_bone_lengths(pred_motions_pos[:, :, :22, :]),
                                        self.calc_bone_lengths(motions_pos[:, :, :22, :]))

            # The last entry along dimension 2 is excluded from the rotation loss.
            loss_geo = self.calc_loss_geo(pred_motion[:, :, :-1, :], motions[:, :, :-1, :])
            loss_fc = self.calc_foot_contact(motions_pos, pred_motions_pos)

            return (loss_rec, loss_explicit, loss_vel, loss_bn,
                    loss_geo, loss_fc, motions_pos, pred_motions_pos)
    

class Inter_Losses:
    def __init__(self,args, recons_loss, joints_num, dataset_name, device):
        self.dataset_name = dataset_name
        self.args=args
        if recons_loss == 'l1':
            self.l1_criterion = torch.nn.L1Loss('none')
        elif recons_loss == 'l1_smooth':
            self.l1_criterion = torch.nn.SmoothL1Loss(reduction='none')
        
        self.joints_num = joints_num
        if self.dataset_name == 'interhuman':
            self.normalizer = MotionNormalizerTorch(device)
        elif self.dataset_name == 'interx':
            self.normalizer = InterxNormalizerTorch()
            self.kinematics = InterxKinematics()
        self.select_joint_index=[0,3,6,9,10,11,12,13,14,15,16,17,18,19,20,21]

    
    def calc_dm_loss(self, motion_joints, pred_motion_joints, thresh_pred=1, thresh_tgt=0.1):

        pred_motion_joints1 = pred_motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        pred_motion_joints2 = pred_motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)
        motion_joints1 = motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        motion_joints2 = motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)
        
        pred_distance_matrix = torch.cdist(pred_motion_joints1.contiguous(), pred_motion_joints2)
        tgt_distance_matrix = torch.cdist(motion_joints1.contiguous(), motion_joints2)
        
        pred_distance_matrix = pred_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, self.joints_num*self.joints_num) # B*T, njoints=22, 22 -> B, T, 484
        tgt_distance_matrix = tgt_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, self.joints_num*self.joints_num)
        
        dm_mask = (pred_distance_matrix < thresh_pred).float()
        dm_tgt_mask = (tgt_distance_matrix < thresh_tgt).float()
        
        dm_loss = (self.l1_criterion(pred_distance_matrix, tgt_distance_matrix) * dm_mask).sum() / (dm_mask.sum() + 1.e-7)
        dm_tgt_loss = (self.l1_criterion(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix)) * dm_tgt_mask).sum()/ (dm_tgt_mask.sum() + 1.e-7)
        
        return dm_loss + dm_tgt_loss

   
    def calc_weight_dm_loss(self, motion_joints, pred_motion_joints, thresh_pred=1, thresh_tgt=0.1):

        pred_motion_joints1 = pred_motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)[...,self.select_joint_index,:]
        pred_motion_joints2 = pred_motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)[...,self.select_joint_index,:]
        motion_joints1 = motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)[...,self.select_joint_index,:]
        motion_joints2 = motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)[...,self.select_joint_index,:]
        
        pred_distance_matrix = torch.cdist(pred_motion_joints1.contiguous(), pred_motion_joints2)
        tgt_distance_matrix = torch.cdist(motion_joints1.contiguous(), motion_joints2)
        
        pred_distance_matrix = pred_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, len(self.select_joint_index)*len(self.select_joint_index)) # B*T, njoints=22, 22 -> B, T, 484
        tgt_distance_matrix = tgt_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, len(self.select_joint_index)*len(self.select_joint_index))
        
        dm_mask = (pred_distance_matrix < thresh_pred).float()
        dm_tgt_mask = (tgt_distance_matrix < thresh_tgt).float()
        
        dm_loss = (self.l1_criterion(pred_distance_matrix, tgt_distance_matrix) * dm_mask).sum() / (dm_mask.sum() + 1.e-7)
        dm_tgt_loss = (self.l1_criterion(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix)) * dm_tgt_mask).sum()/ (dm_tgt_mask.sum() + 1.e-7)
        
        return dm_loss + dm_tgt_loss


    def calc_ro_loss(self, motion_joints, pred_motion_joints):
        """Loss on the relative facing direction of the two people.

        For each person a forward vector is built as the normalised cross
        product of the y axis with the normalised hip-to-hip vector. The
        rotation between the two people's forward vectors is obtained with
        ``qbetween`` for prediction and target, and components 0 and 2 of the
        results are compared with ``self.l1_criterion``.

        No epsilon is used in the normalisations, so coincident hips produce
        NaN.

        Args:
            motion_joints: Target joints; the person axis is third from last.
            pred_motion_joints: Predicted joints with the same layout.

        Returns:
            Mean of the criterion over the compared components.
        """
        # Only the two hip indices are needed; the remaining two are unused here.
        r_hip, l_hip, _, _ = face_joint_indx

        def _unit(vec):
            return vec / vec.norm(dim=-1, keepdim=True)

        across = _unit(pred_motion_joints[..., r_hip, :] - pred_motion_joints[..., l_hip, :])
        across_gt = _unit(motion_joints[..., r_hip, :] - motion_joints[..., l_hip, :])

        up = torch.zeros_like(across)
        up[..., 1] = 1

        facing = _unit(torch.cross(up, across, dim=-1))
        facing_gt = _unit(torch.cross(up, across_gt, dim=-1))

        # Index 0 / 1 on the second-to-last axis selects person 1 / person 2.
        pred_relative_rot = qbetween(facing[..., 0, :], facing[..., 1, :])
        tgt_relative_rot = qbetween(facing_gt[..., 0, :], facing_gt[..., 1, :])

        components = [0, 2]
        return self.l1_criterion(pred_relative_rot[..., components],
                                 tgt_relative_rot[..., components]).mean()
    def calc_period_temporal_dm_loss(self, motion_joints, pred_motion_joints, base_thresh_pred=1.0, base_thresh_tgt=0.1):
        """
        Calculate distance matrix loss with temporal smoothing and periodic thresholds
        Args:
            motion_joints: Ground truth motion [B,T,2,J,3]
            pred_motion_joints: Predicted motion [B,T,2,J,3]
        Returns:
            total_loss: Combined loss of distance matrix, target, and temporal components
        """
        # Input validation
        assert motion_joints.dim() == 5, f"Expected 5D input [B,T,2,J,3], got shape {motion_joints.shape}"
        assert motion_joints.shape == pred_motion_joints.shape, f"Shape mismatch: {motion_joints.shape} vs {pred_motion_joints.shape}"
        
        B, T = motion_joints.shape[:2]
        self.B = B
        self.T = T

        # 1. Extract joint positions and calculate distance matrices

        pred_motion_joints1 = pred_motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        pred_motion_joints2 = pred_motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)
        motion_joints1 = motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        motion_joints2 = motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)
        
        pred_distance_matrix = torch.cdist(pred_motion_joints1.contiguous(), pred_motion_joints2)
        tgt_distance_matrix = torch.cdist(motion_joints1.contiguous(), motion_joints2)
        # Reshape to [B,T,J*J]
        pred_distance_matrix = pred_distance_matrix.reshape(B, T, -1)
        tgt_distance_matrix = tgt_distance_matrix.reshape(B, T, -1)
        # Apply temporal smoothing
        pred_distance_smooth = pred_distance_matrix

        # Calculate periodic thresholds
        t = torch.arange(T, device=motion_joints.device).float()
        
        # Combine multiple periodic components with safe normalization
        thresh_scale = (1.0 + 
                    0.15 * torch.sin(2 * torch.pi * t / max(self.T, 1e-6)) +
                    0.1 * torch.sin(2 * torch.pi * t / 20.0) +
                    0.05 * torch.sin(2 * torch.pi * t / 10.0))
        
        thresh_scale = thresh_scale.unsqueeze(0).unsqueeze(-1)
        
        # Apply dynamic thresholds with safety checks
        dynamic_thresh_pred = torch.clamp(base_thresh_pred * thresh_scale, min=0.1)
        dynamic_thresh_tgt = torch.clamp(base_thresh_tgt * thresh_scale, min=0.01)

        # Calculate masks with dynamic thresholds
        dm_mask = (pred_distance_smooth < dynamic_thresh_pred).float()
        dm_tgt_mask = (tgt_distance_matrix < dynamic_thresh_tgt).float()

        # Calculate temporal continuity loss
        temporal_loss = torch.abs(pred_distance_smooth[:, 1:] - pred_distance_smooth[:, :-1]).mean()
        
        # Calculate masked reconstruction losses with safe denominators
        dm_loss = (self.l1_criterion(pred_distance_smooth, tgt_distance_matrix) * dm_mask).sum()
        dm_mask_sum = dm_mask.sum() + 1e-7
        dm_loss = dm_loss / dm_mask_sum

        dm_tgt_loss = (self.l1_criterion(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix)) * dm_tgt_mask).sum()
        dm_tgt_mask_sum = dm_tgt_mask.sum() + 1e-7
        dm_tgt_loss = dm_tgt_loss / dm_tgt_mask_sum

        # Combine losses with safety checks
        total_loss = dm_loss + dm_tgt_loss + 0.1 * temporal_loss
        
        # Final validation
        if torch.isnan(total_loss) or torch.isinf(total_loss):
            print(f"Warning: Invalid loss values detected - dm_loss: {dm_loss}, dm_tgt_loss: {dm_tgt_loss}, temporal_loss: {temporal_loss}")
            total_loss = torch.tensor(0.0, device=total_loss.device)
        
        return total_loss
    def calc_BVH_penetration_loss(self, motion_joints, pred_motion_joints):
        """
        Calculate penetration loss using AABB and BVH with GT reference
        Args:
            motion_joints: Ground truth motion [B,T,2,J,3]
            pred_motion_joints: Predicted motion [B,T,2,J,3]
        Returns:
            total_loss: Combined penetration loss with GT reference
        """
        B, T = motion_joints.shape[:2]
        
        # Define BVH bone chains
        bvh_chains = {
            'right_leg': [0, 2, 5, 8, 11],
            'left_leg': [0, 1, 4, 7, 10],
            'spine': [0, 3, 6, 9, 12, 15],
            'right_arm': [9, 14, 17, 19, 21],
            'left_arm': [9, 13, 16, 18, 20]
        }

        # Extract joints
        pred1 = pred_motion_joints[..., 0, :, :]  # [B,T,J,3]
        pred2 = pred_motion_joints[..., 1, :, :]  # [B,T,J,3]
        gt1 = motion_joints[..., 0, :, :]         # [B,T,J,3]
        gt2 = motion_joints[..., 1, :, :]         # [B,T,J,3]

        total_pen_loss = 0
        total_mse_loss = 0
        
        for chain1_name, joints1 in bvh_chains.items():
            for chain2_name, joints2 in bvh_chains.items():
                if chain1_name == chain2_name:
                    continue
                
                # Get chain joints
                pred_chain1 = pred1[..., joints1, :]  # [B,T,C1,3]
                pred_chain2 = pred2[..., joints2, :]  # [B,T,C2,3]
                gt_chain1 = gt1[..., joints1, :]      # [B,T,C1,3]
                gt_chain2 = gt2[..., joints2, :]      # [B,T,C2,3]

                # 1. AABB Quick Test
                # Calculate bounding boxes
                pred_min1 = torch.min(pred_chain1, dim=2)[0]  # [B,T,3]
                pred_max1 = torch.max(pred_chain1, dim=2)[0]  # [B,T,3]
                pred_min2 = torch.min(pred_chain2, dim=2)[0]  # [B,T,3]
                pred_max2 = torch.max(pred_chain2, dim=2)[0]  # [B,T,3]

                # AABB overlap test
                overlap = (pred_min1[..., 0] < pred_max2[..., 0]) & \
                        (pred_max1[..., 0] > pred_min2[..., 0]) & \
                        (pred_min1[..., 1] < pred_max2[..., 1]) & \
                        (pred_max1[..., 1] > pred_min2[..., 1]) & \
                        (pred_min1[..., 2] < pred_max2[..., 2]) & \
                        (pred_max1[..., 2] > pred_min2[..., 2])

                # 2. Detailed Distance Calculation for Overlapping Regions
                overlap_mask = overlap.float().unsqueeze(-1).unsqueeze(-1)  # [B,T,1,1]

                # Calculate distances only for overlapping regions
                pred_dists = torch.cdist(pred_chain1, pred_chain2) * overlap_mask  # [B,T,C1,C2]
                gt_dists = torch.cdist(gt_chain1, gt_chain2) * overlap_mask       # [B,T,C1,C2]

                # Get minimum distances (masked)
                pred_min_dist = (pred_dists * overlap_mask).view(B, T, -1).min(dim=-1)[0]  # [B,T]
                gt_min_dist = (gt_dists * overlap_mask).view(B, T, -1).min(dim=-1)[0]      # [B,T]

                # Calculate penetration threshold based on GT
                min_allowed_dist = gt_min_dist * 0.9  # 90% of GT distance

                # Calculate penetration loss (no need for additional overlap.float())
                pen_mask = (pred_min_dist < min_allowed_dist).float()
                chain_pen_loss = torch.nn.functional.relu(
                    min_allowed_dist - pred_min_dist
                ) * pen_mask * overlap.float()  # Use original overlap here

                # Calculate MSE between predicted and GT distances
                chain_mse_loss = torch.nn.functional.mse_loss(
                    pred_min_dist,
                    gt_min_dist,
                    reduction='none'
                )

                # Weight based on chain types
                weight = 1.0
                if 'spine' in chain1_name or 'spine' in chain2_name:
                    weight = 2.0
                elif ('arm' in chain1_name and 'leg' in chain2_name) or \
                    ('leg' in chain1_name and 'arm' in chain2_name):
                    weight = 0.5

                # Add weighted losses
                total_pen_loss += chain_pen_loss.mean() * weight
                total_mse_loss += chain_mse_loss.mean() * weight

        # Combine penetration and MSE losses
        final_loss = 0.7 * total_pen_loss + 0.3 * total_mse_loss
        
        return final_loss
    def compute_min_bone_distances(self,chain1, chain2):
        """
        Compute minimum distance between all line segments in chain1 and chain2.
        Args:
            chain1: [B,T,L1,3] joints forming L1-1 bones
            chain2: [B,T,L2,3] joints forming L2-1 bones
        Returns:
            min_dists: [B,T] minimum bone-to-bone distance
        """
        B, T, L1, _ = chain1.shape
        L2 = chain2.shape[2]

        seg1_start = chain1[:, :, :-1, :]
        seg1_end = chain1[:, :, 1:, :]
        seg2_start = chain2[:, :, :-1, :]
        seg2_end = chain2[:, :, 1:, :]

        s1 = seg1_start.unsqueeze(3)
        e1 = seg1_end.unsqueeze(3)
        s2 = seg2_start.unsqueeze(2)
        e2 = seg2_end.unsqueeze(2)

        u = e1 - s1
        v = e2 - s2
        w0 = s1 - s2

        a = (u * u).sum(-1)
        b = (u * v).sum(-1)
        c = (v * v).sum(-1)
        d = (u * w0).sum(-1)
        e = (v * w0).sum(-1)

        denom = a * c - b * b + 1e-6
        sc = (b * e - c * d) / denom
        tc = (a * e - b * d) / denom

        sc = sc.clamp(0.0, 1.0)
        tc = tc.clamp(0.0, 1.0)

        closest_point1 = s1 + sc.unsqueeze(-1) * u
        closest_point2 = s2 + tc.unsqueeze(-1) * v

        dists = ((closest_point1 - closest_point2) ** 2).sum(-1)
        min_dists = dists.view(B, T, -1).min(dim=-1)[0]
        return min_dists.sqrt()
    def calc_dm_loss_gauss(self, motion_joints, pred_motion_joints, sigma=0.5, tgt_sigma=0.05):
        """
        距离矩阵损失（带高斯平滑惩罚项）
        """
        # 提取两人的骨架序列
        pred_motion_joints1 = pred_motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        pred_motion_joints2 = pred_motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)
        motion_joints1 = motion_joints[..., 0:1, :, :].reshape(-1, self.joints_num, 3)
        motion_joints2 = motion_joints[..., 1:2, :, :].reshape(-1, self.joints_num, 3)

        # 计算预测与GT的距离矩阵
        pred_distance_matrix = torch.cdist(pred_motion_joints1.contiguous(), pred_motion_joints2)  # [B*T, J, J]
        tgt_distance_matrix = torch.cdist(motion_joints1.contiguous(), motion_joints2)

        # reshape 回 [B, T, J*J]
        pred_distance_matrix = pred_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, -1)
        tgt_distance_matrix = tgt_distance_matrix.reshape(pred_distance_matrix.shape[0], -1).reshape(self.B, self.T, -1)

        # 计算高斯权重（小距离→高权重）
        pred_gauss_weights = torch.exp(-pred_distance_matrix**2 / (2 * sigma**2))   # [B, T, J*J]
        tgt_gauss_weights = torch.exp(-tgt_distance_matrix**2 / (2 * tgt_sigma**2)) # GT中贴近对应点也应预测为贴近

        # 高斯加权损失项
        dm_loss = (self.l1_criterion(pred_distance_matrix, tgt_distance_matrix) * pred_gauss_weights).sum() / (pred_gauss_weights.sum() + 1e-7)
        dm_tgt_loss = (self.l1_criterion(pred_distance_matrix, torch.zeros_like(tgt_distance_matrix)) * tgt_gauss_weights).sum() / (tgt_gauss_weights.sum() + 1e-7)

        return dm_loss + dm_tgt_loss


    def calc_MGIC_penetration_loss(self, motion_joints, pred_motion_joints):
        """
        Penetration loss between two people from AABB tests and bone distances.

        For every ordered pair (chain of person 1, chain of person 2) the
        axis-aligned bounding boxes of the *predicted* chains are compared per
        frame. Minimum bone-to-bone distances are computed through
        ``self.compute_min_bone_distances`` only for frames whose boxes
        overlap, for both the prediction and the ground truth. Every other
        frame keeps a distance of zero in both buffers and therefore
        contributes nothing to either loss term.

        Args:
            motion_joints: Ground truth motion [B,T,2,J,3]
            pred_motion_joints: Predicted motion [B,T,2,J,3]
                (joint indices up to 21 are used, so J must be at least 22)
        Returns:
            total_loss: Scalar tensor. Sum over all chain pairs of
                (a) the mean hinge penalty for predicted distances below
                90% of the ground-truth distance and (b) the mean squared
                error between predicted and ground-truth distances.
        """
        B, T = motion_joints.shape[:2]

        # Joint indices that make up each kinematic chain.
        bvh_chains = {
            'right_leg': [0, 2, 5, 8, 11],
            'left_leg': [0, 1, 4, 7, 10],
            'spine': [0, 3, 6, 9, 12, 15],
            'right_arm': [9, 14, 17, 19, 21],
            'left_arm': [9, 13, 16, 18, 20]
        }
        # Predicted distance may shrink to this fraction of the GT distance
        # before it is penalised.
        allowed_ratio = 0.9

        pred1 = pred_motion_joints[..., 0, :, :]
        pred2 = pred_motion_joints[..., 1, :, :]
        gt1 = motion_joints[..., 0, :, :]
        gt2 = motion_joints[..., 1, :, :]

        total_pen_loss = 0
        total_mse_loss = 0

        for joints1 in bvh_chains.values():
            # Person-1 data only depends on the outer chain: compute it once.
            pred_chain1 = pred1[..., joints1, :]
            gt_chain1 = gt1[..., joints1, :]
            pred_min1 = pred_chain1.min(dim=2)[0]
            pred_max1 = pred_chain1.max(dim=2)[0]

            for joints2 in bvh_chains.values():
                pred_chain2 = pred2[..., joints2, :]
                gt_chain2 = gt2[..., joints2, :]
                pred_min2 = pred_chain2.min(dim=2)[0]
                pred_max2 = pred_chain2.max(dim=2)[0]

                # Boxes overlap only if they overlap (strictly) on every axis.
                overlap = ((pred_min1 < pred_max2) & (pred_max1 > pred_min2)).all(dim=-1)

                # Allocate with the inputs' dtype/device; a fixed float32
                # buffer makes the indexed assignment below fail for
                # float64 / half inputs.
                pred_min_dist = pred1.new_zeros(B, T)
                gt_min_dist = gt1.new_zeros(B, T)

                if overlap.any():
                    mask_idx = overlap.nonzero(as_tuple=True)

                    pred_d = self.compute_min_bone_distances(
                        pred_chain1[mask_idx].unsqueeze(1),
                        pred_chain2[mask_idx].unsqueeze(1),
                    ).squeeze(1)
                    gt_d = self.compute_min_bone_distances(
                        gt_chain1[mask_idx].unsqueeze(1),
                        gt_chain2[mask_idx].unsqueeze(1),
                    ).squeeze(1)

                    # Scatter the per-frame distances back into [B, T].
                    pred_min_dist[mask_idx] = pred_d.to(pred_min_dist.dtype)
                    gt_min_dist[mask_idx] = gt_d.to(gt_min_dist.dtype)

                min_allowed_dist = gt_min_dist * allowed_ratio
                # The hinge is already zero where the prediction is far
                # enough, and non-overlapping frames have both distances at
                # zero, so no additional masking is required.
                chain_pen_loss = F.relu(min_allowed_dist - pred_min_dist)
                chain_mse_loss = F.mse_loss(pred_min_dist, gt_min_dist, reduction='none')

                total_pen_loss += chain_pen_loss.mean()
                total_mse_loss += chain_mse_loss.mean()

        return total_pen_loss + total_mse_loss



    def forward(self, motion1, motion2, pred_motion1, pred_motion2):
        """
        Compute the distance-matrix, relative-orientation and penetration losses.

        The two people are stacked on a new "person" axis. For 'interhuman'
        the stacked features are de-normalised with ``self.normalizer.backward``
        and the first ``joints_num * 3`` features are reshaped into joint
        positions [B, T, P, joints_num, 3]. For 'interx' the inputs are
        stacked on dim 2 and used as they are (presumably already joint
        positions [B, T, J, 3] -- TODO confirm against the caller).

        Args:
            motion1, motion2: ground-truth motion of person 1 / person 2.
            pred_motion1, pred_motion2: predicted motion of person 1 / person 2.
        Returns:
            (dm_loss, ro_loss, pen_loss). ``dm_loss`` is chosen by
            ``args.DM_select`` ("DM", "WDM", "PTDM", "GDM") and ``pen_loss``
            by ``args.Pen_select`` ("BVH", "MGIC"); any other value yields a
            zero scalar tensor on the device of the joints.
        Raises:
            ValueError: if ``self.dataset_name`` is neither 'interhuman'
                nor 'interx'.
        """
        B, T = motion1.shape[:2]
        # Kept on the instance exactly as before; other code may read them.
        self.B = B
        self.T = T

        if self.dataset_name == 'interhuman':
            motions = torch.stack([motion1, motion2], dim=-2)
            motions = self.normalizer.backward(motions)

            pred_motion = torch.stack([pred_motion1, pred_motion2], dim=-2)
            pred_motion = self.normalizer.backward(pred_motion)

            # Only the leading joints_num * 3 features are joint positions.
            pred_motion_joints = pred_motion[..., :self.joints_num * 3].reshape(B, T, -1, self.joints_num, 3)
            motion_joints = motions[..., :self.joints_num * 3].reshape(B, T, -1, self.joints_num, 3)
        elif self.dataset_name == 'interx':
            motion_joints = torch.stack([motion1, motion2], dim=2)
            pred_motion_joints = torch.stack([pred_motion1, pred_motion2], dim=2)
        else:
            # Fail loudly instead of hitting an UnboundLocalError further down.
            raise ValueError(f"Unsupported dataset_name: {self.dataset_name!r}")

        ro_loss = self.calc_ro_loss(motion_joints, pred_motion_joints)

        if self.args.DM_select == "DM":
            dm_loss = self.calc_dm_loss(motion_joints, pred_motion_joints)
        elif self.args.DM_select == "WDM":
            dm_loss = self.calc_weight_dm_loss(motion_joints, pred_motion_joints)
        elif self.args.DM_select == "PTDM":
            dm_loss = self.calc_period_temporal_dm_loss(motion_joints, pred_motion_joints)
        elif self.args.DM_select == "GDM":
            dm_loss = self.calc_dm_loss_gauss(motion_joints, pred_motion_joints)
        else:
            # Distance-matrix loss disabled.
            dm_loss = torch.zeros((), device=motion_joints.device)

        if self.args.Pen_select == "BVH":
            pen_loss = self.calc_BVH_penetration_loss(motion_joints, pred_motion_joints)
        elif self.args.Pen_select == "MGIC":
            # Weighted sum of the penetration term and the Gaussian
            # distance-matrix term.
            ppl_term = self.calc_MGIC_penetration_loss(motion_joints, pred_motion_joints)
            gdm_term = self.calc_dm_loss_gauss(motion_joints, pred_motion_joints)
            pen_loss = self.args.lambda_PPL * ppl_term + self.args.lambda_GDM * gdm_term
        else:
            # Penetration loss disabled.
            pen_loss = torch.zeros((), device=motion_joints.device)

        return dm_loss, ro_loss, pen_loss
    




        # 2. Kalman Filter for temporal smoothing
        # def kalman_smooth(x):
        #     """Kalman smoother for temporal sequences"""
        #     # State space model parameters
        #     A = torch.eye(2, device=x.device)  # State transition matrix
        #     H = torch.tensor([[1.0, 0.0]], device=x.device)  # Measurement matrix
        #     Q = torch.eye(2, device=x.device) * 0.001  # Process noise
        #     R = torch.tensor(0.1, device=x.device)  # Measurement noise
            
        #     batch_size, seq_len, feat_dim = x.shape
        #     smoothed = torch.zeros_like(x)
            
        #     # For each batch and feature
        #     for b in range(batch_size):
        #         for d in range(feat_dim):
        #             # Initialize
        #             state = torch.zeros(2, device=x.device)
        #             P = torch.eye(2, device=x.device) * 100
                    
        #             # Forward pass
        #             states, covs = [], []
        #             for t in range(seq_len):
        #                 # Predict
        #                 state = A @ state
        #                 P = A @ P @ A.T + Q
                        
        #                 # Update
        #                 K = P @ H.T / (H @ P @ H.T + R)
        #                 state = state + K @ (x[b,t,d:d+1] - H @ state)
        #                 P = (torch.eye(2, device=x.device) - K @ H) @ P
                        
        #                 states.append(state)
        #                 covs.append(P)
                    
        #             # Backward smoothing
        #             smoothed_state = states[-1]
        #             smoothed[b,-1,d] = smoothed_state[0]
                    
        #             for t in range(seq_len-2, -1, -1):
        #                 # RTS smoother equations
        #                 J = covs[t] @ A.T @ torch.inverse(A @ covs[t] @ A.T + Q)
        #                 smoothed_state = states[t] + J @ (smoothed_state - A @ states[t])
        #                 smoothed[b,t,d] = smoothed_state[0]
                        
        #     return smoothed